"""

This example learns to classify hand-written digits from the
[MNIST](http://yann.lecun.com/exdb/mnist) dataset.  There are 60000
training and 10000 test examples. Each input x consists of 784 pixels
representing a 28x28 image.  The pixel values are normalized to
[0,1]. Each output y is an integer class label in 1:10 indicating the
correct class for a given image; the label 10 is used to represent the
digit 0.

You can run the demo using `julia mlp.jl` on the command line or
`julia> MLP.main()` at the Julia prompt.  Options can be used like
`julia mlp.jl --epochs 3` or `julia> MLP.main("--epochs 3")`.  Use
`julia mlp.jl --help` for a list of options.  The dataset will be
automatically downloaded.  By default a softmax model will be trained
for 10 epochs.  You can also train a multi-layer perceptron by
specifying one or more --hidden sizes.  The accuracy for the training
and test sets will be printed at every epoch and optimized parameters
will be returned.

"""
module MLP
using Knet,CUDA,MLDatasets,ArgParse

"""
    predict(w, x)

Run the input `x` through the layers stored in `w` and return the output of
the last layer.

`w` is indexed as a flat sequence of pairs: `w[1], w[2]` are the weight and
bias of the first layer, `w[3], w[4]` of the second, and so on.  Each layer
computes `weight * mat(input) .+ bias`; `relu` is applied elementwise after
every layer except the last one.  When `w` is empty, `x` is returned as is.
"""
function predict(w, x)
    outputindex = length(w) - 1  # position of the last layer's weight in `w`
    activation = x
    for k in 1:2:length(w)
        activation = w[k] * mat(activation) .+ w[k + 1]
        if k < outputindex
            # hidden layer: elementwise max(0, x)
            activation = relu.(activation)
        end
    end
    return activation
end

"""
    loss(w, x, ygold)

Return `nll` of the scores `predict(w, x)` against the gold answers `ygold`.
"""
function loss(w, x, ygold)
    scores = predict(w, x)
    return nll(scores, ygold)
end

# Gradient function for `loss`, built once with `grad`; `train` calls it as
# `lossgradient(w, x, y)`.  Declared `const` so the binding cannot change type
# and calls made from inside `train` do not go through an untyped global.
const lossgradient = grad(loss)

"""
    train(w, dtrn; lr=0.5, epochs=10)

Iterate over `dtrn` `epochs` times.  For every `(x, y)` pair it yields, compute
`lossgradient(w, x, y)` and pass the result to `update!(w, g; lr=lr)`.
Returns `w`.
"""
function train(w, dtrn; lr=0.5, epochs=10)
    # Equivalent to an outer loop over epochs with an inner pass over `dtrn`;
    # `dtrn` is iterated from the start again for every epoch.
    for _ in 1:epochs, (inputs, answers) in dtrn
        gradients = lossgradient(w, inputs, answers)
        update!(w, gradients; lr=lr)
    end
    return w
end

"""
    weights(h...; atype=Array{Float32}, winit=0.1, inputsize=28*28, outputsize=10)

Create the parameter list for a multi-layer perceptron.

# Arguments
- `h...`: sizes of the hidden layers, in order; pass none for a single
  linear layer mapping inputs directly to outputs.

# Keywords
- `atype`: array type every parameter is converted to.
- `winit`: scale applied to the `randn` draw used for weight matrices.
- `inputsize`: number of input features (defaults to 28*28).
- `outputsize`: number of output units (defaults to 10).

# Returns
A `Vector{Any}` of alternating weight matrices and bias columns.  For
consecutive layer sizes `nin -> nout` the weight has size `(nout, nin)` and is
filled with `winit * randn(nout, nin)`; the bias has size `(nout, 1)` and is
all zeros.
"""
function weights(h...; atype=Array{Float32}, winit=0.1, inputsize=28*28, outputsize=10)
    w = Any[]
    nin = inputsize
    for nout in (h..., outputsize)
        push!(w, convert(atype, winit * randn(nout, nin)))
        push!(w, convert(atype, zeros(nout, 1)))
        nin = nout  # this layer's output feeds the next layer
    end
    return w
end

"""
    main(args="")

Parse command-line style options, train the model on MNIST and return the
trained parameter list.

`args` may be a single string (split on whitespace) or a collection of
argument strings such as `ARGS`.  If `--help` or `-h` is present the help text
is printed and `nothing` is returned.

Unless `--fast` is given, the description and the parsed options are printed,
accuracy on the training and test batches is reported before training and
after every epoch, and `--gcheck N` runs `gradcheck` after each epoch.  With
`--fast` the whole training run is timed as a single call with no reporting.

Side effect: the training and test minibatch iterators are stored in the
module-level globals `dtrn` and `dtst`.
"""
function main(args="")
    s = ArgParseSettings()
    s.description="mlp.jl (c) Deniz Yuret, 2016. Multi-layer perceptron model on the MNIST handwritten digit recognition problem from http://yann.lecun.com/exdb/mnist."
    s.exc_handler=ArgParse.debug_handler
    @add_arg_table s begin
        ("--seed"; arg_type=Int; default=-1; help="random number seed: use a nonnegative int for repeatable results")
        ("--batchsize"; arg_type=Int; default=100; help="minibatch size")
        ("--epochs"; arg_type=Int; default=10; help="number of epochs for training")
        ("--hidden"; nargs='*'; arg_type=Int; help="sizes of hidden layers, e.g. --hidden 128 64 for a net with two hidden layers")
        ("--lr"; arg_type=Float64; default=0.5; help="learning rate")
        ("--winit"; arg_type=Float64; default=0.1; help="w initialized with winit*randn()")
        ("--fast"; action=:store_true; help="skip loss printing for faster run")
        ("--atype"; default="$(Knet.array_type[])"; help="array type: Array for cpu, KnetArray for gpu")
        ("--gcheck"; arg_type=Int; default=0; help="check N random gradients per parameter")
        # These are to experiment with sparse arrays
        # ("--xtype"; help="input array type: defaults to atype")
        # ("--ytype"; help="output array type: defaults to atype")
    end
    # Accept a single option string for interactive use.
    isa(args, AbstractString) && (args=split(args))
    # Handle help here so an interactive session is not terminated.
    if in("--help", args) || in("-h", args)
        ArgParse.show_help(s; exit_when_done=false)
        return
    end
    o = parse_args(args, s; as_symbols=true)
    if !o[:fast]
        println(s.description)
        println("opts=",[(k,v) for (k,v) in o]...)
    end
    # The default of -1 means "do not seed"; every nonnegative value, including
    # 0, is a valid seed as promised by the --seed help text.
    o[:seed] >= 0 && Knet.seed!(o[:seed])
    # NOTE(review): this evaluates the --atype string as Julia code in this
    # module.  Acceptable for a local demo script, but do not feed it
    # untrusted input.
    atype = eval(Meta.parse(o[:atype]))
    w = weights(o[:hidden]...; atype=atype, winit=o[:winit])
    # Relabel class 0 as 10 so that labels lie in 1:10.
    xtrn,ytrn = MNIST.traindata(Float32); ytrn[ytrn .== 0] .= 10
    xtst,ytst = MNIST.testdata(Float32);  ytst[ytst .== 0] .= 10
    xsize = (size(xtrn,1),size(xtrn,2),1,o[:batchsize])
    # Kept as module globals (existing behavior), e.g. for inspection after a run.
    global dtrn = minibatch(xtrn, ytrn, o[:batchsize]; xtype=atype,xsize=xsize)
    global dtst = minibatch(xtst, ytst, o[:batchsize]; xtype=atype,xsize=xsize)
    report(epoch)=println((:epoch,epoch,:trn,accuracy(x->predict(w,x); data=dtrn),:tst,accuracy(x->predict(w,x); data=dtst)))
    if o[:fast]
        # Synchronize inside the timed expression when CUDA is usable.
        @time (train(w, dtrn; lr=o[:lr], epochs=o[:epochs]); CUDA.functional() && CUDA.synchronize())
    else
        report(0)
        @time for epoch=1:o[:epochs]
            train(w, dtrn; lr=o[:lr], epochs=1)
            report(epoch)
            if o[:gcheck] > 0
                gradcheck(loss, w, first(dtrn)...; gcheck=o[:gcheck], verbose=true)
            end
        end
    end
    return w
end

# This allows both non-interactive (shell command) and interactive calls like:
# $ julia mlp.jl --epochs 10
# julia> MLP.main("--epochs 10")
# The absolute path of the running script is compared with this file's path so
# that `main` also runs when the script is invoked through a directory, e.g.
# `julia examples/mlp.jl` or `julia ./mlp.jl`, not only as `julia mlp.jl`.
if abspath(PROGRAM_FILE) == @__FILE__
    main(ARGS)
end

end # module

# SAMPLE RUN 65f57ff+ Wed Sep 14 10:02:30 EEST 2016
#
# mnist2d.jl (c) Deniz Yuret, 2016. Multi-layer perceptron model on the MNIST handwritten digit recognition problem from http://yann.lecun.com/exdb/mnist.
# opts=(:seed,-1)(:batchsize,100)(:hidden,Int64[])(:epochs,10)(:lr,0.5)(:atype,"KnetArray{Float32}")(:gcheck,0)(:winit,0.1)(:fast,true)
# (:epoch,0,:trn,0.079066664f0,:tst,0.0842f0)
#   2.168927 seconds (2.95 M allocations: 115.993 MB, 1.84% gc time)
# (:epoch,10,:trn,0.9195333f0,:tst,0.9158f0)
